import torch
import torch.nn as nn
import sys
import os
parent_folder = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
sys.path.append(parent_folder)
from Record.file_management import read_obj_dumps, strip_instance
import argparse
from Vae.data_utils.data_util import _CustomDataParallel
from Vae.models.cnn_vae import CNNVAE
import numpy as np
import PIL.Image as Image

from Vae.data_utils.vae_dataset import ImageDataset
import glob
from tqdm import tqdm
import pickle

def load_encodings(pth):
    """Load a single pickle file and return its contents as a new list.

    The unpickled object is iterated and every element is appended to a fresh
    list, so a list/tuple on disk comes back as a flat list of its items.

    Args:
        pth: path to a pickle file (e.g. one written by
            ``record_encodings_and_mask_state``).

    Returns:
        list containing the elements of the unpickled object.
    """
    collected = []
    # NOTE: pickle.load can execute arbitrary code -- only use on trusted files.
    with open(pth, 'rb') as handle:
        collected.extend(pickle.load(handle))
    return collected


def get_pixel_xy_from_image(image, missing_value=-10):
    """Return the mean (row, col) location of the non-zero entries of ``image``.

    Only the first two index arrays of ``np.where(image != 0)`` are used. For an
    image with a trailing channel axis, the mean is therefore taken over every
    non-zero (row, col, channel) entry: a pixel with several non-zero channels
    contributes once per channel.

    Args:
        image: array with at least two dimensions. NOTE(review): presumably
            (h, w) or (h, w, c) -- confirm against callers.
        missing_value: value used for both coordinates when ``image`` has no
            non-zero entries (default -10, used to denote "object missing").

    Returns:
        float ndarray of shape (2,): ``[mean_row, mean_col]``, or
        ``[missing_value, missing_value]`` if nothing is non-zero.
    """
    rows, cols = np.where(image != 0)[:2]
    # Check emptiness *before* averaging: np.mean of an empty array emits
    # "Mean of empty slice" RuntimeWarnings and returns NaN.
    if rows.size == 0:
        return np.full(2, missing_value, dtype=float)
    return np.array([rows.mean(), cols.mean()])


def get_state_from_images(images):  # images are ordered from newest to oldest
    """Build a ``[position, velocity]`` state vector from a stack of frames.

    Each frame is reduced to a 2-D location with ``get_pixel_xy_from_image``.
    The position is that of the first (newest) frame. The velocity is the mean
    of ``positions[k] - positions[k + 1]`` over consecutive frame pairs.

    NOTE(review): frames where the object is missing yield the sentinel
    location from ``get_pixel_xy_from_image``, so velocities involving such a
    frame are not physically meaningful -- confirm downstream handling.

    Args:
        images: sequence/array/tensor of frames, newest first. A torch tensor is
            moved to host memory and converted to numpy.

    Returns:
        ndarray of shape (4,): ``[pos_0, pos_1, vel_0, vel_1]``. With a single
        frame the velocity is zero.

    Raises:
        ValueError: if ``images`` contains no frames.
    """
    if isinstance(images, torch.Tensor):
        images = images.detach().cpu().numpy()
    if len(images) == 0:
        raise ValueError("get_state_from_images requires at least one frame")
    positions = np.stack([get_pixel_xy_from_image(image) for image in images])
    if len(positions) < 2:
        # No previous frame to difference against. Averaging an empty list
        # would yield a NaN scalar that np.concatenate rejects.
        velocities = np.zeros(2)
    else:
        velocities = np.mean(positions[:-1] - positions[1:], axis=0)
    return np.concatenate([positions[0], velocities])

def record_encodings_and_mask_state(model, obj_data, dataset, frame_stack, save_rollouts,
                                    *, device='cuda', batch_size=2048, num_workers=4):
    """Encode every sample of ``dataset`` with ``model`` and pickle the results.

    For each sample, the ``mean`` output of ``model(x)`` (second element of the
    returned ``(x_hat, mean, log_var)`` tuple) is stored as the object's
    encoding. A position/velocity state computed by ``get_state_from_images``
    from the sample's frames is also prepended to that encoding.

    Two pickles are written to ``save_rollouts``:
      * ``encodings_vae.pkl``: list indexed by frame number, each entry a dict
        ``{obj_name: encoding}``.
      * ``encodings_vae_mask_state.pkl``: same layout, each value is
        ``np.concatenate([mask_state, encoding])``.

    Args:
        model: module whose forward returns ``(x_hat, mean, log_var)``.
        obj_data: only ``len(obj_data)`` is used, to size the per-frame lists.
            Every ``frame_number`` from ``dataset`` must be a valid index into
            a list of that length.
        dataset: yields dicts with ``'image'``, ``'frame_number'`` and
            ``'obj_name'`` entries.
        frame_stack: number of frames per sample. Each image of shape
            ``(c, h, w)`` is reshaped to ``(frame_stack, 3, h, w)``, so ``c``
            must equal ``3 * frame_stack`` (presumably RGB frames stacked along
            the channel axis -- confirm against the dataset).
        save_rollouts: output directory; created if it does not exist.
        device: device the batches are moved to for the forward pass.
        batch_size: DataLoader batch size.
        num_workers: DataLoader worker count (kept low because of memory issues).
    """
    model.eval()
    n_frames = len(obj_data)
    # One dict per time step: {obj_name: value}.
    encodings = [{} for _ in range(n_frames)]
    encodings_and_mask_states = [{} for _ in range(n_frames)]
    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        drop_last=False,
    )

    with torch.no_grad():
        for data in tqdm(data_loader, desc="Encoding"):
            images = data['image']
            _, mean, _ = model(images.to(device))
            _, _, h, w = images.shape
            # Bring each batch to host memory once. The old code issued one small
            # device->host copy per sample for both the encoding and the frames.
            images_np = images.cpu().numpy()
            means_np = mean.detach().cpu().numpy()
            frame_idxs = data['frame_number'].tolist()
            for i, (frame_idx, obj_name) in enumerate(zip(frame_idxs, data['obj_name'])):
                obj_encoding = means_np[i]
                frames = images_np[i].reshape(frame_stack, 3, h, w).transpose(0, 2, 3, 1)
                obj_mask_state = get_state_from_images(frames)
                encodings[frame_idx][obj_name] = obj_encoding
                encodings_and_mask_states[frame_idx][obj_name] = np.concatenate(
                    [obj_mask_state, obj_encoding]
                )

    non_empty_encodings = sum(len(frame) > 0 for frame in encodings)
    print("nonempty encodings: ", non_empty_encodings)

    # Create the directory up front so a missing path doesn't throw away the
    # (expensive) encoding pass at the very end.
    os.makedirs(save_rollouts, exist_ok=True)
    with open(os.path.join(save_rollouts, 'encodings_vae.pkl'), 'wb') as f:
        pickle.dump(encodings, f)
    with open(os.path.join(save_rollouts, 'encodings_vae_mask_state.pkl'), 'wb') as f:
        pickle.dump(encodings_and_mask_states, f)